{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1E_HhLEeYqFG"
   },
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "<br>汉化的库: <a href=\"https://github.com/GoatCsu/CN-LLMs-from-scratch.git\">https://github.com/GoatCsu/CN-LLMs-from-scratch.git</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZuWudYFWYiH7"
   },
   "source": [
    "# **高效加载模型权重（Memory-efficient Model Weight Loading）**  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qt0Qyg6ewUt6"
   },
   "source": [
    "- 本笔记本提供了一些 **在 GPU（或 CPU）内存受限时加载大型预训练或微调模型的技巧**。  \n",
    "- 具体来说，它重点介绍了 **如何加载使用 `torch.save(model.state_dict(), \"model.pth\")` 保存的模型**（例如在 **第 5-7 章** 中），以便在新会话中继续 **预训练（Pretraining）或额外微调（Finetuning）**。  \n",
    "- **尽管示例使用的是 LLM**，但本笔记本介绍的方法 **适用于任何 PyTorch 模型的加载**，不仅限于 LLM。  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/memory-efficient-loading/memory-efficient-loading.webp\" width=\"800px\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "SxQzFoS-IXdY",
    "outputId": "b28ebfbd-9036-4696-d95a-7f96fdf29919"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "memory_profiler version: 0.61.0\n",
      "torch version: 2.4.1+cu121\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "\n",
    "pkgs = [\n",
    "    \"torch\",\n",
    "]\n",
    "for p in pkgs:\n",
    "    print(f\"{p} version: {version(p)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y47iQaQKyHap"
   },
   "source": [
    "&nbsp;\n",
    "## 1. **基准测试工具（Benchmark Utilities）**  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nQeOEoo6yT0X"
   },
   "source": [
    "- 首先，我们定义一些 **用于追踪 VRAM（GPU 内存）** 的工具函数。  \n",
    "- 之后，我们还将 **引入一个工具来监测主系统 RAM（CPU 内存）**。  \n",
    "- 这些函数的作用将在后续应用时变得更加清晰。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "pEiqjYrVivgt"
   },
   "outputs": [],
   "source": [
    "import gc\n",
    "import time\n",
    "import torch\n",
    "\n",
    "\n",
    "def start_memory_tracking():\n",
    "    \"\"\"Initialize GPU memory tracking.\"\"\"\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.reset_peak_memory_stats()\n",
    "    else:\n",
    "        print(\"This notebook is intended for CUDA GPUs but CUDA is not available.\")\n",
    "\n",
    "def print_memory_usage():\n",
    "    max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # Convert bytes to GB\n",
    "    print(f\"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB\")\n",
    "\n",
    "def cleanup():\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    time.sleep(3)  # some buffer time to allow memory to clear\n",
    "    torch.cuda.reset_peak_memory_stats()\n",
    "    max_memory_allocated = torch.cuda.max_memory_allocated(device) / (1024 ** 3)\n",
    "    print(f\"Maximum GPU memory allocated: {max_memory_allocated:.1f} GB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z5oJwoc-kkXs"
   },
   "source": [
    "&nbsp;\n",
    "## 2. 模型设置"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YfJE0vnMyr88"
   },
   "source": [
    "- 该代码部分 **用于初始化模型**。  \n",
    "- 这里，我们使用 **\"GPT-2 Large\" 模型** 以增加实验的挑战性（如果希望减少 **内存占用** 和 **执行时间**，可以选择 **\"GPT-2 Small (124M)\"**）。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "tMuhCYaVI0w7"
   },
   "outputs": [],
   "source": [
    "from previous_chapters import GPTModel\n",
    "\n",
    "\n",
    "BASE_CONFIG = {\n",
    "    \"vocab_size\": 50257,     # Vocabulary size\n",
    "    \"context_length\": 1024,  # Context length\n",
    "    \"drop_rate\": 0.0,        # Dropout rate\n",
    "    \"qkv_bias\": True         # Query-key-value bias\n",
    "}\n",
    "\n",
    "model_configs = {\n",
    "    \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
    "    \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
    "    \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
    "    \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
    "}\n",
    "\n",
    "CHOOSE_MODEL = \"gpt2-xl (1558M)\"\n",
    "\n",
    "BASE_CONFIG.update(model_configs[CHOOSE_MODEL])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KWYoo1z5y8aX"
   },
   "source": [
    "- 现在，我们 **测试 GPU 内存监测函数的运行情况**。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "GK3NEA3eJv3f",
    "outputId": "60573d6e-c603-45e7-8283-b1e92e2a0013"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 6.4 GB\n"
     ]
    }
   ],
   "source": [
    "start_memory_tracking()\n",
    "\n",
    "\n",
    "model = GPTModel(BASE_CONFIG)\n",
    "device = torch.device(\"cuda\")\n",
    "model.to(device)\n",
    "\n",
    "print_memory_usage()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GIhwBEBxzBsF"
   },
   "source": [
    "- 此外，我们通过 **输入示例张量（tensor）** 来确保 **模型能够正常运行**。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "i_j6nZruUd7g"
   },
   "outputs": [],
   "source": [
    "# Test if the model works (no need to track memory here)\n",
    "test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    model(test_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UgNb8c32zh4g"
   },
   "source": [
    "- 接下来，假设我们 **完成了模型的预训练，并希望将其保存以便后续使用**。  \n",
    "- **（为简洁起见，这里跳过实际的预训练过程，仅保存初始化的模型，但概念相同。）** "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "wUIXjcsimXU7"
   },
   "outputs": [],
   "source": [
    "# Training code would go here...\n",
    "\n",
    "model.train()\n",
    "torch.save(model.state_dict(), \"model.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "s9tBS4HUzz1g"
   },
   "source": [
    "- 最后，我们 **在 Python 会话中删除模型和示例张量，以释放 GPU 内存**。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "SqmTzztqKnTs",
    "outputId": "1198afb9-2d97-4b6a-9bdb-41551f25749d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 0.0 GB\n"
     ]
    }
   ],
   "source": [
    "del model, test_input\n",
    "cleanup()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7EnO8beUJ6Sb"
   },
   "source": [
    "&nbsp;\n",
    "## 3. 加载权重"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JtAXKjsG0AVL"
   },
   "source": [
    "- **接下来进入关键部分**，我们将 **加载预训练模型的权重**。  \n",
    "- 让我们 **检查加载之前保存的模型所需的 GPU 内存**。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "wCrQNbSJJO9w",
    "outputId": "9b203868-a8ef-4011-fc2b-611cc0d10994"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 12.8 GB\n"
     ]
    }
   ],
   "source": [
    "# Then load pretrained weights\n",
    "\n",
    "start_memory_tracking()\n",
    "\n",
    "model = GPTModel(BASE_CONFIG)\n",
    "model.to(device)\n",
    "\n",
    "model.load_state_dict(\n",
    "    torch.load(\"model.pth\", map_location=device, weights_only=True)\n",
    ")\n",
    "model.to(device)\n",
    "model.eval();\n",
    "\n",
    "print_memory_usage()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4AGvOrcN0KdJ"
   },
   "source": [
    "- **注意**，当前 **内存占用是前一阶段的 2 倍**。  \n",
    "- 这是因为，在加载模型权重时，**短暂地在内存中存储了两份模型**：\n",
    "  - **第一次**：通过 `model.to(device)` 将模型移动到设备（GPU/CPU）。  \n",
    "  - **第二次**：调用  \n",
    "    ```python\n",
    "    model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))\n",
    "    ```  \n",
    "    **在这一步，模型权重会被加载到 `state_dict` 中，然后复制到模型本体**。但在 **复制完成前**，**内存中同时存在完整的模型和加载的 `state_dict`**，导致占用翻倍。  \n",
    "- **接下来的章节将介绍如何优化这一过程，以减少内存占用**。  \n",
    "- 在此之前，我们先 **测试模型是否正确加载，并重置 GPU 内存**。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DvlUn-nmmbuj",
    "outputId": "11d3ab68-f570-4c1e-c631-fe5547026799"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 0.0 GB\n"
     ]
    }
   ],
   "source": [
    "# Test if the model works (no need to track memory here)\n",
    "test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    model(test_input)\n",
    "\n",
    "del model, test_input\n",
    "cleanup()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RdPnW3iLLrjX"
   },
   "source": [
    "&nbsp;\n",
    "## **4. 按顺序加载权重（Loading Weights Sequentially）**  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FYqtUON602TD"
   },
   "source": [
    "- **为了解决上一节提到的“模型权重在 GPU 内存中出现两次”的问题**，可以采用 **按顺序加载（sequential loading）** 的方法。  \n",
    "- 具体来说，我们 **分步加载模型**：\n",
    "  1. **首先**，将 **模型架构** 加载到 **GPU 内存**。  \n",
    "  2. **然后**，将 **模型权重** 先加载到 **CPU 内存**。  \n",
    "  3. **最后**，逐个 **参数** 复制到 **GPU 内存**，避免一次性加载导致的内存峰值。  \n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DOIGTNWTmx9G",
    "outputId": "145162e6-aaa6-4c2a-ed8f-f1cf068adb80"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 6.4 GB\n",
      "Maximum GPU memory allocated: 6.7 GB\n"
     ]
    }
   ],
   "source": [
    "start_memory_tracking()\n",
    "\n",
    "model = GPTModel(BASE_CONFIG).to(device)\n",
    "\n",
    "state_dict = torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n",
    "\n",
    "print_memory_usage()\n",
    "\n",
    "# Sequentially copy weights to the model's parameters\n",
    "with torch.no_grad():\n",
    "    for name, param in model.named_parameters():\n",
    "        if name in state_dict:\n",
    "            param.copy_(state_dict[name].to(device))\n",
    "        else:\n",
    "            print(f\"Warning: {name} not found in state_dict.\")\n",
    "\n",
    "print_memory_usage()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Pn9xD_xL1ZzM"
   },
   "source": [
    "- 如上所示，**采用按序加载方法后，内存占用明显降低**。  \n",
    "- 需要注意的是，**内存从 6.4GB 增加到 6.7GB**，原因如下：  \n",
    "  - **最初**，仅模型结构被加载到 **GPU 内存**。  \n",
    "  - **随后**，每次加载一个参数张量（parameter tensor）并移动到 GPU，以便使用 `\".to()\"` 方法将其赋值到模型中。  \n",
    "  - **整个过程中，我们只在 GPU 中存储一个额外的参数张量**，避免了大规模的额外占用。  \n",
    "- **总体来看，这种方法显著降低了 GPU 内存峰值**。  \n",
    "- 接下来，我们 **简单测试模型的正确性**，然后 **重置 GPU 内存，为下一节做准备**。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "PRHnjA48nJgw",
    "outputId": "dcd6b1b2-538f-4862-96a6-a5fcbf3326a4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 0.0 GB\n"
     ]
    }
   ],
   "source": [
    "# Test if the model works (no need to track memory here)\n",
    "test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    model(test_input)\n",
    "\n",
    "del model, test_input, state_dict, param\n",
    "cleanup()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5M92LK7usb-Z"
   },
   "source": [
    "&nbsp;\n",
    "## **5. 在低 CPU 内存环境中加载模型（Loading the Model with Low CPU Memory）**  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "R45qgeB613e2"
   },
   "source": [
    "- 在上一节中，我们通过 **先将权重 (`state_dict`) 加载到 CPU 内存，再逐个复制到 GPU**，成功降低了 **GPU 内存占用**。  \n",
    "- 但是，如果 **CPU 内存也有限**，我们该如何加载模型？  \n",
    "- 本节将介绍 **PyTorch 的 `\"meta\"` 设备（Meta Device）方法**，该方法适用于 **GPU 内存充足但 CPU 内存较小的设备**。  \n",
    "- 在此之前，我们先 **定义一个工具函数，用于监测 CPU 内存使用情况**。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "BrcWy0q-3Bbe"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import psutil\n",
    "from threading import Thread\n",
    "\n",
    "\n",
    "def memory_usage_in_gb(func, *args, **kwargs):\n",
    "    process = psutil.Process(os.getpid())\n",
    "\n",
    "    # Measure the baseline memory usage before running the function\n",
    "    baseline_mem = process.memory_info().rss / 1024 ** 3  # in GB\n",
    "\n",
    "    # Start monitoring memory in a separate thread\n",
    "    mem_usage = []\n",
    "    done = False\n",
    "\n",
    "    def monitor_memory():\n",
    "        while not done:\n",
    "            mem_usage.append(process.memory_info().rss / 1024 ** 3)  # Convert to GB\n",
    "            time.sleep(0.1)\n",
    "\n",
    "    t = Thread(target=monitor_memory)\n",
    "    t.start()\n",
    "\n",
    "    # Run the function\n",
    "    func(*args, **kwargs)\n",
    "\n",
    "    # Stop monitoring\n",
    "    done = True\n",
    "    t.join()\n",
    "\n",
    "    peak_mem_usage_gb = max(mem_usage) - baseline_mem\n",
    "    return peak_mem_usage_gb\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ayy30Ytd5hjF"
   },
   "source": [
    "- **首先，我们追踪上一节** 使用 **顺序加载权重（Sequential Weight Loading）** 方法时的 **CPU 内存占用情况**。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "rCkV6IbQtpVn",
    "outputId": "26c0435a-1e3d-4e8f-fbe2-f9655bad61b4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 6.4 GB\n",
      "Maximum GPU memory allocated: 6.7 GB\n",
      "-> Maximum CPU memory allocated: 6.3 GB\n"
     ]
    }
   ],
   "source": [
    "def load_sequentially():\n",
    "    start_memory_tracking()\n",
    "\n",
    "    model = GPTModel(BASE_CONFIG).to(device)\n",
    "\n",
    "    state_dict = torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "    # Sequentially copy weights to the model's parameters\n",
    "    with torch.no_grad():\n",
    "        for name, param in model.named_parameters():\n",
    "            if name in state_dict:\n",
    "                param.copy_(state_dict[name].to(device))\n",
    "            else:\n",
    "                print(f\"Warning: {name} not found in state_dict.\")\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "\n",
    "peak_memory_used = memory_usage_in_gb(load_sequentially)\n",
    "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UWrmnCML5oKy"
   },
   "source": [
    "- **假设我们使用的设备 CPU 内存较小，但 GPU 内存充足**。  \n",
    "- 我们可以 **利用 PyTorch 的 `\"meta\"` 设备**，在 **CPU 和 GPU 内存占用之间进行权衡**。  \n",
    "- **PyTorch 的 `\"meta\"` 设备** 是一种特殊的设备类型，它允许创建 **不实际分配内存** 的张量，即 **\"meta\" 张量**。  \n",
    "- 这对于 **模型分析（Model Analysis）或架构定义（Architecture Definition）** 等任务非常有用，在这些任务中，我们只需要 **张量的形状和数据类型**，而不需要 **真正的内存分配**。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "PBErC_5Yt8ly",
    "outputId": "8799db06-191c-47c4-92fa-fbb95d685aa9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 12.8 GB\n",
      "Maximum GPU memory allocated: 12.8 GB\n",
      "-> Maximum CPU memory allocated: 1.3 GB\n"
     ]
    }
   ],
   "source": [
    "def load_sequentially_with_meta():\n",
    "    start_memory_tracking()\n",
    "\n",
    "    with torch.device(\"meta\"):\n",
    "        model = GPTModel(BASE_CONFIG)\n",
    "\n",
    "    model = model.to_empty(device=device)\n",
    "\n",
    "    state_dict = torch.load(\"model.pth\", map_location=device, weights_only=True)\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "    # Sequentially copy weights to the model's parameters\n",
    "    with torch.no_grad():\n",
    "        for name, param in model.named_parameters():\n",
    "            if name in state_dict:\n",
    "                param.copy_(state_dict[name])\n",
    "            else:\n",
    "                print(f\"Warning: {name} not found in state_dict.\")\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "peak_memory_used = memory_usage_in_gb(load_sequentially_with_meta)\n",
    "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VpnCABp75-VQ"
   },
   "source": [
    "- 如上所示，通过 **在 `\"meta\"` 设备上创建模型，并直接将权重加载到 GPU 内存**，我们 **显著降低了 CPU 内存占用**。  \n",
    "- 这可能引发一个问题：**“那么，顺序加载权重（Sequential Weight Loading）是否仍然有必要？它与原始方法相比效果如何？”**  \n",
    "- 接下来，我们对比 **PyTorch 的标准权重加载方法**（即本笔记本开头的 **初始权重加载方式**），看看它的 CPU 和 GPU 内存占用情况。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "4f-bqBNRuR39",
    "outputId": "f7c0a901-b404-433a-9b93-2bbfa8183c56"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 12.8 GB\n",
      "-> Maximum CPU memory allocated: 4.4 GB\n"
     ]
    }
   ],
   "source": [
    "def baseline():\n",
    "    start_memory_tracking()\n",
    "\n",
    "    model = GPTModel(BASE_CONFIG)\n",
    "    model.to(device)\n",
    "\n",
    "    model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))\n",
    "    model.to(device)\n",
    "    model.eval();\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "peak_memory_used = memory_usage_in_gb(baseline)\n",
    "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NKAjxbX86xnb"
   },
   "source": [
    "- 如上所示，**不使用 `\"meta\"` 设备的“简单”权重加载方式，占用了更多的内存**。  \n",
    "- 换句话说，**如果设备的 CPU 内存有限**，可以采用 **`\"meta\"` 设备方法**，直接将模型权重加载到 **GPU 内存**，从而 **降低 CPU 内存峰值占用**。  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "&nbsp;\n",
    "## **6. 使用 `mmap=True`（推荐方法）**  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- **如果你是 PyTorch `torch.load` 的中高级用户**，可能会想知道这些方法与 **`mmap=True` 选项** 有何区别。  \n",
    "- **`mmap=True`** 选项 **启用了内存映射文件 I/O（Memory-Mapped File I/O）**，使得张量可以 **直接从磁盘访问数据**，而不需要将整个文件加载到 RAM 中，从而在 **RAM 受限时显著降低内存占用**。  \n",
    "- 另请参考 **[mikaylagawarecki 的有用评论](https://github.com/rasbt/LLMs-from-scratch/issues/402)**。  \n",
    "- 乍一看，`mmap=True` **可能比前面介绍的顺序加载方法效率更低**：  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "GKwV0AMNemuR",
    "outputId": "e207f2bf-5c87-498e-80fe-e8c4016ac711"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 6.4 GB\n",
      "-> Maximum CPU memory allocated: 5.9 GB\n"
     ]
    }
   ],
   "source": [
    "def best_practices():\n",
    "  with torch.device(\"meta\"):\n",
    "      model = GPTModel(BASE_CONFIG)\n",
    "\n",
    "  model.load_state_dict(\n",
    "      torch.load(\"model.pth\", map_location=device, weights_only=True, mmap=True),\n",
    "      assign=True\n",
    "  )\n",
    "\n",
    "  print_memory_usage()\n",
    "\n",
    "peak_memory_used = memory_usage_in_gb(best_practices)\n",
    "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- **当前 CPU RAM 占用较高的原因**，是因为 **该设备的 CPU 内存充足**，因此 PyTorch 默认将模型加载到 RAM 中。  \n",
    "- 但是，**如果设备的 CPU RAM 受限**，`mmap` 方法 **会显著降低内存使用量**，因为它允许张量 **直接从磁盘访问数据，而无需将整个文件加载到 RAM**。  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "&nbsp;\n",
    "## 7. 其他方法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 本笔记本主要介绍 **PyTorch 内置的简单方法**，用于高效加载模型权重。  \n",
    "- **在 CPU 内存受限的情况下**，推荐使用 **`mmap=True` 方法**（前文已详细介绍）。  \n",
    "- 另外，还有一种 **“暴力”方法**：即 **将每个权重张量（Tensor）分别存储和加载**，以减少内存占用。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "2CgPEZUIb00w"
   },
   "outputs": [],
   "source": [
    "model = GPTModel(BASE_CONFIG)\n",
    "# Assume `model` is your trained model\n",
    "state_dict = model.state_dict()\n",
    "\n",
    "# Create a directory to store individual parameter files\n",
    "os.makedirs(\"model_parameters\", exist_ok=True)\n",
    "\n",
    "# Save each parameter tensor separately\n",
    "for name, param in state_dict.items():\n",
    "    torch.save(param.cpu(), f\"model_parameters/{name}.pt\")\n",
    "\n",
    "del model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "gTsmtJK-b4yy",
    "outputId": "d361e2d3-e34c-48d7-9047-846c9bfd291e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 6.4 GB\n",
      "Maximum GPU memory allocated: 6.4 GB\n",
      "-> Maximum CPU memory allocated: 0.3 GB\n"
     ]
    }
   ],
   "source": [
    "def load_individual_weights():\n",
    "\n",
    "    start_memory_tracking()\n",
    "\n",
    "    with torch.device(\"meta\"):\n",
    "        model = GPTModel(BASE_CONFIG)\n",
    "\n",
    "    model = model.to_empty(device=device)\n",
    "\n",
    "    print_memory_usage()\n",
    "    param_dir = \"model_parameters\"\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for name, param in model.named_parameters():\n",
    "            weight_path = os.path.join(param_dir, f\"{name}.pt\")\n",
    "            if os.path.exists(weight_path):\n",
    "                param_data = torch.load(weight_path, map_location=\"cpu\", weights_only=True)\n",
    "                param.copy_(param_data)\n",
    "                del param_data  # Free memory\n",
    "            else:\n",
    "                print(f\"Warning: {name} not found in {param_dir}.\")\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "\n",
    "peak_memory_used = memory_usage_in_gb(load_individual_weights)\n",
    "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "L4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
